1. Introduction

This analysis explores EEG spectrogram embeddings to predict expert consensus classifications. We’ll implement:

  1. Dimensionality reduction techniques suitable for embedding vectors
  2. Three machine learning models with extensive hyperparameter tuning:
    • LightGBM (gradient boosting)
    • Random Forest with feature importance
    • Multinomial Logistic Regression on PCA-reduced embeddings
  3. Comprehensive performance evaluation including confusion matrices, F1 scores, and AUC metrics

2. Environment Setup

# Install any of the requested packages that are not yet available locally.
# `packages` is a character vector of package names; already-installed
# packages are left untouched.
install_if_missing <- function(packages) {
  installed_names <- installed.packages()[, "Package"]
  absent <- packages[is.na(match(packages, installed_names))]

  # Nothing to do when every requested package is already present
  if (length(absent) == 0) {
    return(invisible(NULL))
  }

  install.packages(
    absent,
    dependencies = TRUE,
    repos = "https://cloud.r-project.org"
  )
}

# List of required packages
required_packages <- c("tidyverse", "caret", "ranger", "lightgbm", "e1071", 
                      "umap", "ggplot2", "corrplot", "doParallel", "pROC", 
                      "reshape2", "viridis", "gridExtra", "nnet") # Add nnet for multinom

# Install missing packages
install_if_missing(required_packages)

# Load necessary libraries
library(tidyverse)
library(caret)
library(ranger)
library(lightgbm)
library(e1071)
library(umap)
library(ggplot2)
library(corrplot)
library(doParallel)
library(pROC)
library(reshape2)
library(viridis)
library(gridExtra)

# Set seed for reproducibility
set.seed(42)

# Set up parallel processing, leaving one core free for the rest of the system.
# detectCores() may return NA (unknown) or 1, in which case "cores - 1" would
# be NA or 0; clamp the worker count so at least one worker is registered.
available_cores <- parallel::detectCores()
n_workers <- if (is.na(available_cores)) 1L else max(1L, available_cores - 1L)
registerDoParallel(cores = n_workers)

3. Data Loading and Exploration

# Using file.path for OS-independent path handling
file_path <- file.path("..", "out", "test", "vit_embeddings_test.csv")
data <- read.csv(file_path)  # Assign to 'data' variable

# Print successful data load
cat("Successfully loaded data with", nrow(data), "rows and", ncol(data), "columns\n")
## Successfully loaded data with 2136 rows and 774 columns
# Basic data information
str(data, list.len = 10)
## 'data.frame':    2136 obs. of  774 variables:
##  $ spectrogram_id    : int  686402130 959372535 250501602 1872858502 1540613004 1606814511 1694525835 1845795531 1993693261 38177271 ...
##  $ spectrogram_sub_id: int  30 5 8 1 7 0 5 2 36 12 ...
##  $ eeg_id            : num  3.53e+08 1.76e+09 3.69e+08 1.84e+09 2.79e+09 ...
##  $ eeg_sub_id        : int  30 5 8 1 7 0 5 2 36 12 ...
##  $ offset_seconds    : num  108 18 44 4 34 0 12 22 178 26 ...
##  $ expert_consensus  : chr  "GPD" "GRDA" "Seizure" "LPD" ...
##  $ embedding_0       : num  -0.186 0.534 0.611 0.503 0.359 ...
##  $ embedding_1       : num  -0.92 -0.552 0.656 0.861 1.12 ...
##  $ embedding_2       : num  0.0599 -1.1648 -1.043 -0.7725 0.172 ...
##  $ embedding_3       : num  0.368583 -0.200468 -0.045985 -0.198015 -0.000211 ...
##   [list output truncated]
# Count NA entries per column and report how many columns are affected
missing_values <- colSums(is.na(data))
if (!any(missing_values > 0)) {
  cat("No missing values found in the dataset\n")
} else {
  cat("Missing values found in", sum(missing_values > 0), "columns\n")
}
## No missing values found in the dataset
# Convert target to factor
data$expert_consensus <- as.factor(data$expert_consensus)

# Target distribution
target_dist <- table(data$expert_consensus)
print("Class distribution:")
## [1] "Class distribution:"
print(target_dist)
## 
##     GPD    GRDA     LPD    LRDA   Other Seizure 
##     321     373     297     324     391     430
# Turn the class counts into a plotting frame (one row per class)
barplot_data <- data.frame(
  Class = names(target_dist),
  Count = as.numeric(target_dist)
)

# Bar chart of class frequencies with the count printed above each bar
ggplot(barplot_data, aes(x = Class, y = Count, fill = Class)) +
  geom_col() +
  geom_text(aes(label = Count), vjust = -0.5) +
  scale_fill_viridis_d() +
  labs(
    title = "Distribution of Expert Consensus Classes",
    x = "Class",
    y = "Count"
  ) +
  theme_minimal() +
  theme(legend.position = "none")

4. Data Preprocessing and Feature Analysis

# Features are all columns whose name contains "embedding_";
# the target is the expert consensus label
is_embedding_col <- grepl("embedding_", names(data))
embedding_cols <- names(data)[is_embedding_col]
X <- data[, embedding_cols]
y <- data$expert_consensus

# Stratified 70/30 split into training and test sets
set.seed(42)
train_index <- createDataPartition(y, p = 0.7, list = FALSE)

X_train <- X[train_index, ]
y_train <- y[train_index]

X_test <- X[-train_index, ]
y_test <- y[-train_index]

# Check embedding dimension
cat("Embedding dimension:", ncol(X), "\n")
## Embedding dimension: 768
# Per-embedding summary statistics computed on the training split
# (one row per embedding column)
embedding_stats <- data.frame(
  mean = colMeans(X_train),
  std = vapply(X_train, sd, numeric(1)),
  min = vapply(X_train, min, numeric(1)),
  max = vapply(X_train, max, numeric(1)),
  median = vapply(X_train, median, numeric(1))
)

# 2x2 panel of histograms showing how the statistics are distributed
par(mfrow = c(2, 2))
hist(embedding_stats$mean, col = "skyblue",
     main = "Embedding Mean Distribution")
hist(embedding_stats$std, col = "lightgreen",
     main = "Embedding Std Distribution")
hist(embedding_stats$min, col = "salmon",
     main = "Embedding Min Values")
hist(embedding_stats$max, col = "plum",
     main = "Embedding Max Values")

# Restore the single-panel layout
par(mfrow = c(1, 1))

# Sample correlation matrix (for a subset of embeddings)
set.seed(123)
sample_size <- min(30, length(embedding_cols))
sample_embeddings <- sample(embedding_cols, sample_size)
corr_matrix <- cor(X_train[, sample_embeddings])

# Visualize correlation
corrplot(corr_matrix, method = "circle", type = "upper", 
         tl.cex = 0.6, tl.col = "black", tl.srt = 45,
         title = "Sample Embedding Correlation")

# Apply UMAP for visualization
set.seed(42)
umap_result <- umap(X_train, n_components = 2, n_neighbors = 15, min_dist = 0.1)

# Create visualization data
umap_df <- data.frame(
  UMAP1 = umap_result$layout[,1],
  UMAP2 = umap_result$layout[,2],
  Class = y_train
)

# Plot UMAP visualization
ggplot(umap_df, aes(x = UMAP1, y = UMAP2, color = Class)) +
  geom_point(alpha = 0.7) +
  scale_color_viridis_d() +
  labs(title = "UMAP Visualization of EEG Embedding Space",
       subtitle = "Colored by Expert Consensus Class") +
  theme_minimal()

# Fit PCA on the training features (centered, unit-variance scaled);
# the reduced representation is used for the LR model
pca_result <- prcomp(X_train, center = TRUE, scale. = TRUE)

# Rows 2 and 3 of the summary's importance table hold the per-component and
# cumulative proportion of variance
pca_importance <- summary(pca_result)$importance
variance_explained <- pca_importance[2, ]
cumulative_variance <- pca_importance[3, ]

# Smallest number of leading components reaching 95% cumulative variance
n_components_95 <- which(cumulative_variance >= 0.95)[1]
cat("Number of PCA components for 95% variance:", n_components_95, "\n")
## Number of PCA components for 95% variance: 118
# Plot variance explained
plot(cumulative_variance, type = "b", xlab = "Principal Component", 
     ylab = "Cumulative Proportion of Variance", 
     main = "PCA Cumulative Variance")
abline(h = 0.95, col = "red", lty = 2)

# Create PCA-transformed data for LR.
# Both splits are projected with the PCA fitted on the training data and only
# the leading components are kept. drop = FALSE keeps the result a matrix with
# its "PC*" column names even when a single component is retained, so the
# train and test columns still match once wrapped in data frames.
keep_components <- seq_len(n_components_95)
X_train_pca <- predict(pca_result, X_train)[, keep_components, drop = FALSE]
X_test_pca <- predict(pca_result, X_test)[, keep_components, drop = FALSE]

5. Model 1: LightGBM with Hyperparameter Tuning

# Prepare data for LightGBM
num_classes <- nlevels(y_train)
y_train_numeric <- as.integer(y_train) - 1  # LightGBM expects 0-based class ids

dtrain <- lgb.Dataset(
  data = as.matrix(X_train),
  label = y_train_numeric
)

# Base parameters shared by every candidate in the tuning grid
lgb_params <- list(
  objective = "multiclass",
  metric = "multi_logloss",
  num_class = num_classes,
  verbose = -1,
  # The grid varies min_data_in_leaf while reusing one Dataset; presumably
  # this is why pre-filtering must be disabled (it avoided an error here).
  feature_pre_filter = FALSE,
  # bagging_fraction is ignored by LightGBM unless bagging_freq is non-zero,
  # so enable bagging at every iteration to make the grid values effective.
  bagging_freq = 1
)

# Evaluate one hyperparameter set with k-fold cross-validation.
#
# params  : named list of LightGBM parameters, appended to the shared
#           `lgb_params` base list.
# nrounds : maximum number of boosting rounds.
# nfold   : number of CV folds.
#
# Returns a list with `score` (best CV metric value reported by lgb.cv) and
# `iteration` (the boosting round at which that score was reached).
evaluate_params <- function(params, nrounds = 100, nfold = 5) {
  full_params <- c(lgb_params, params)

  cv_result <- lgb.cv(
    params = full_params,
    data = dtrain,
    nrounds = nrounds,
    nfold = nfold,
    early_stopping_rounds = 20,
    verbose = 0
  )

  # best_iter is already the index of the best round; the previous
  # which.min() on this scalar always yielded 1.
  best_iter <- cv_result$best_iter
  if (is.null(best_iter) || is.na(best_iter) || best_iter < 1) {
    best_iter <- nrounds  # fall back to the full budget if none was recorded
  }

  list(
    score = cv_result$best_score,
    iteration = best_iter
  )
}

# Grid of parameters to try
param_grid <- list(
  list(num_leaves = 31, learning_rate = 0.1, max_depth = 6, 
       feature_fraction = 0.8, bagging_fraction = 0.8, min_data_in_leaf = 20),
  list(num_leaves = 63, learning_rate = 0.05, max_depth = 8, 
       feature_fraction = 0.7, bagging_fraction = 0.7, min_data_in_leaf = 30),
  list(num_leaves = 127, learning_rate = 0.01, max_depth = 10, 
       feature_fraction = 0.9, bagging_fraction = 0.9, min_data_in_leaf = 10)
)

# Find best parameters
cat("Tuning LightGBM hyperparameters...\n")
## Tuning LightGBM hyperparameters...
# Evaluate every candidate in param_grid and remember the one with the lowest
# cross-validated score.
lgb_results <- vector("list", length(param_grid))
best_score <- Inf
best_params <- NULL
best_iter <- 100

# seq_along() yields an empty loop for an empty grid (1:length() would not)
for (i in seq_along(param_grid)) {
  cat("Evaluating parameter set", i, "of", length(param_grid), "\n")
  result <- evaluate_params(param_grid[[i]])

  lgb_results[[i]] <- list(
    params = param_grid[[i]],
    score = result$score,
    iteration = result$iteration
  )

  # isTRUE() skips candidates whose score is NA instead of erroring
  if (isTRUE(result$score < best_score)) {
    best_score <- result$score
    best_params <- param_grid[[i]]
    best_iter <- result$iteration
  }
}
## Evaluating parameter set 1 of 3 
## Evaluating parameter set 2 of 3 
## Evaluating parameter set 3 of 3
# Print best parameters
cat("\nBest LightGBM parameters:\n")
## 
## Best LightGBM parameters:
print(best_params)
## $num_leaves
## [1] 63
## 
## $learning_rate
## [1] 0.05
## 
## $max_depth
## [1] 8
## 
## $feature_fraction
## [1] 0.7
## 
## $bagging_fraction
## [1] 0.7
## 
## $min_data_in_leaf
## [1] 30
cat("Best CV score:", best_score, "\n")
## Best CV score: 1.364927
# Train final LightGBM model with best parameters
final_params <- c(lgb_params, best_params)

set.seed(42)
lgb_model <- lgb.train(
  params = final_params,
  data = dtrain,
  nrounds = 100,  # Typically you'd use best_iter, setting to 100 for simplicity
  verbose = 0
)

# Feature importance
lgb_importance <- lgb.importance(lgb_model, percentage = TRUE)
lgb_imp_plot <- lgb.plot.importance(lgb_importance, top_n = 20, measure = "Gain")

# Make predictions.
# Depending on the installed lightgbm version, predict() on a multiclass model
# returns either an (observations x classes) matrix or a flat vector laid out
# observation by observation. Only the flat vector needs reshaping; reshaping
# an existing matrix with byrow = TRUE would scramble the probabilities and
# produce chance-level predictions.
lgb_probs <- predict(lgb_model, as.matrix(X_test))
if (is.matrix(lgb_probs)) {
  lgb_prob_matrix <- lgb_probs
} else {
  lgb_prob_matrix <- matrix(lgb_probs, ncol = num_classes, byrow = TRUE)
}
stopifnot(
  nrow(lgb_prob_matrix) == nrow(X_test),
  ncol(lgb_prob_matrix) == num_classes
)

# Most probable class per observation, as 0-based class ids.
# ties.method = "first" keeps the result deterministic (the default is random).
lgb_preds <- max.col(lgb_prob_matrix, ties.method = "first") - 1
lgb_preds <- factor(lgb_preds, levels = 0:(num_classes - 1))

# Convert y_test to the same 0-based factor levels for comparison
y_test_0idx <- as.integer(y_test) - 1
y_test_factor <- factor(y_test_0idx, levels = 0:(num_classes - 1))

# Calculate confusion matrix
lgb_cm <- confusionMatrix(lgb_preds, y_test_factor)
print("LightGBM Performance:")
## [1] "LightGBM Performance:"
print(lgb_cm$overall)
##       Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
##    0.167449139    0.001565432    0.139307911    0.198702152    0.201877934 
## AccuracyPValue  McnemarPValue 
##    0.988080707    0.131046639

6. Model 2: Random Forest with Feature Importance

# Bundle features and labels per split (the labels are already factors)
train_data <- data.frame(X_train, class = y_train)
test_data <- data.frame(X_test, class = y_test)

# 5-fold cross-validation over an explicit grid, keeping class probabilities
control <- trainControl(
  method = "cv",
  number = 5,
  search = "grid",
  classProbs = TRUE,
  verboseIter = FALSE
)

# Random forest tuning grid: three mtry sizes derived from the feature count,
# two split rules and three minimum node sizes
n_features <- ncol(X_train)
rf_grid <- expand.grid(
  mtry = c(floor(sqrt(n_features)), n_features %/% 3, n_features %/% 5),
  splitrule = c("gini", "extratrees"),
  min.node.size = c(1, 5, 10)
)

# Train Random Forest
cat("Training Random Forest model...\n")
## Training Random Forest model...
set.seed(42)

rf_model <- train(
  x = X_train,  
  y = y_train,  # Make sure y_train is a factor
  method = "ranger",
  trControl = control,
  tuneGrid = rf_grid,
  importance = "impurity",
  num.trees = 500
)
# Display tuning results
print(rf_model)
## Random Forest 
## 
## 1497 samples
##  768 predictor
##    6 classes: 'GPD', 'GRDA', 'LPD', 'LRDA', 'Other', 'Seizure' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 1199, 1197, 1197, 1197, 1198 
## Resampling results across tuning parameters:
## 
##   mtry  splitrule   min.node.size  Accuracy   Kappa    
##    27   gini         1             0.4822965  0.3723073
##    27   gini         5             0.4802898  0.3696975
##    27   gini        10             0.4822831  0.3719814
##    27   extratrees   1             0.4782898  0.3667707
##    27   extratrees   5             0.4802831  0.3690726
##    27   extratrees  10             0.4789609  0.3674337
##   153   gini         1             0.4829855  0.3731245
##   153   gini         5             0.4816477  0.3715242
##   153   gini        10             0.4762830  0.3649651
##   153   extratrees   1             0.4749564  0.3634977
##   153   extratrees   5             0.4742786  0.3622398
##   153   extratrees  10             0.4742942  0.3621538
##   256   gini         1             0.4876499  0.3788775
##   256   gini         5             0.4689743  0.3563254
##   256   gini        10             0.4816343  0.3713702
##   256   extratrees   1             0.4749632  0.3631573
##   256   extratrees   5             0.4769475  0.3656100
##   256   extratrees  10             0.4776164  0.3661686
## 
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 256, splitrule = gini
##  and min.node.size = 1.
plot(rf_model)

# Variable importance
rf_importance <- varImp(rf_model)
plot(rf_importance, top = 20, main = "Random Forest: Top 20 Variables")

# Make predictions
rf_preds <- predict(rf_model, test_data)
rf_probs <- predict(rf_model, test_data, type = "prob")

# Calculate confusion matrix
rf_cm <- confusionMatrix(rf_preds, y_test)
print("Random Forest Performance:")
## [1] "Random Forest Performance:"
print(rf_cm$overall)
##       Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
##   5.023474e-01   3.959477e-01   4.628639e-01   5.418092e-01   2.018779e-01 
## AccuracyPValue  McnemarPValue 
##   5.891456e-64   8.036481e-11

7. Model 3: Logistic Regression

# Prepare data for Logistic Regression - using PCA reduced data for computational efficiency
train_data_pca <- data.frame(X_train_pca, class = y_train)
test_data_pca <- data.frame(X_test_pca, class = y_test)

# Define parameter grid for Logistic Regression
lr_grid <- expand.grid(
  decay = c(0, 0.001, 0.01, 0.1)  # Regularization parameter
)

# Train Logistic Regression
cat("Training Logistic Regression model...\n")
## Training Logistic Regression model...
set.seed(42)
lr_model <- train(
  class ~ .,
  data = train_data_pca,
  method = "multinom",  # Multinomial logistic regression
  trControl = control,
  tuneGrid = lr_grid,
  trace = FALSE,  # Suppress convergence messages
  MaxNWts = 5000  # Increase max weights for high-dimensional data
)

# Display tuning results
print(lr_model)
## Penalized Multinomial Regression 
## 
## 1497 samples
##  118 predictor
##    6 classes: 'GPD', 'GRDA', 'LPD', 'LRDA', 'Other', 'Seizure' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 1199, 1197, 1197, 1197, 1198 
## Resampling results across tuning parameters:
## 
##   decay  Accuracy   Kappa    
##   0.000  0.4048023  0.2828969
##   0.001  0.4054689  0.2837044
##   0.010  0.4054689  0.2837044
##   0.100  0.4081490  0.2870644
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was decay = 0.1.
plot(lr_model)

# Make predictions
lr_preds <- predict(lr_model, test_data_pca)
lr_probs <- predict(lr_model, test_data_pca, type = "prob")

# Calculate confusion matrix
lr_cm <- confusionMatrix(lr_preds, y_test)
print("Logistic Regression Performance:")
## [1] "Logistic Regression Performance:"
print(lr_cm$overall)
##       Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
##   4.303599e-01   3.124930e-01   3.915864e-01   4.697798e-01   2.018779e-01 
## AccuracyPValue  McnemarPValue 
##   3.882984e-39   4.922007e-05
# Variable importance if available
if (any(grepl("varImp", methods(class = class(lr_model)[1])))) {
  lr_importance <- varImp(lr_model)
  plot(lr_importance, top = 20, main = "Logistic Regression: Top 20 Variables")
}

8. Model Comparison and Evaluation

# Collect performance metrics
models <- c("LightGBM", "Random Forest", "LR")
accuracy <- c(lgb_cm$overall["Accuracy"], rf_cm$overall["Accuracy"], lr_cm$overall["Accuracy"])
kappa <- c(lgb_cm$overall["Kappa"], rf_cm$overall["Kappa"], lr_cm$overall["Kappa"])

# Extract per-class F1 scores from a caret confusionMatrix-like object.
#
# cm : list with a `byClass` element. For multiclass problems this is a matrix
#      with one row per class and an "F1" column; for two-class problems it is
#      a named vector containing an "F1" entry.
#
# Returns a named numeric vector of F1 scores (one per class row, or the
# single "F1" entry in the two-class case).
calculate_f1 <- function(cm) {
  by_class <- cm$byClass
  if (is.matrix(by_class)) {
    return(by_class[, "F1"])
  }
  # Two-class case: byClass is a plain named vector, so column indexing fails
  by_class["F1"]
}

f1_lgb <- calculate_f1(lgb_cm)
f1_rf <- calculate_f1(rf_cm)
f1_lr <- calculate_f1(lr_cm)

# Mean F1 scores
mean_f1 <- c(mean(f1_lgb, na.rm = TRUE), 
             mean(f1_rf, na.rm = TRUE), 
             mean(f1_lr, na.rm = TRUE))

# Create comparison table
metrics_df <- data.frame(
  Model = models,
  Accuracy = accuracy,
  Kappa = kappa,
  Mean_F1 = mean_f1
)
# Plot a confusion matrix as a heat map.
#
# cm    : caret confusionMatrix object; cm$table has predictions in rows and
#         reference (true) classes in columns.
# title : plot title.
#
# Returns a ggplot object with true classes on the y axis, predicted classes
# on the x axis, and each tile showing the count and the percentage of that
# true class.
plot_cm <- function(cm, title) {
  # as.data.frame() on the table yields columns in the order
  # Prediction, Reference, Freq - name them in that order (the previous
  # positional rename swapped the two, transposing the plot).
  cm_table <- as.data.frame(cm$table)
  names(cm_table) <- c("Prediction", "Reference", "Freq")

  # Ensure both axes share the same ordered set of class labels
  all_levels <- levels(as.factor(c(as.character(cm_table$Reference),
                                   as.character(cm_table$Prediction))))
  cm_table$Reference <- factor(cm_table$Reference, levels = all_levels)
  cm_table$Prediction <- factor(cm_table$Prediction, levels = all_levels)

  # Percentages within each true class (each row of the plot sums to 100%)
  cm_table <- cm_table %>%
    group_by(Reference) %>%
    mutate(Total = sum(Freq),
           Percentage = Freq / Total * 100) %>%
    ungroup()

  ggplot(cm_table, aes(x = Prediction, y = Reference, fill = Percentage)) +
    geom_tile() +
    # Display both count and percentage
    geom_text(aes(label = sprintf("%d\n(%.1f%%)", Freq, Percentage)),
              color = "black", size = 3) +
    scale_fill_gradient2(low = "white", high = "darkblue", mid = "skyblue",
                         midpoint = 50, limits = c(0, 100)) +
    labs(title = title,
         x = "Predicted Class",
         y = "True Class") +
    theme_minimal() +
    # Rotate x labels so long class names stay readable
    theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 10),
          axis.text.y = element_text(size = 10),
          plot.title = element_text(hjust = 0.5, size = 14))
}

# Plot confusion matrices
lgb_cm_plot <- plot_cm(lgb_cm, "LightGBM Confusion Matrix")
rf_cm_plot <- plot_cm(rf_cm, "Random Forest Confusion Matrix")
lr_cm_plot <- plot_cm(lr_cm, "LR Confusion Matrix")

# Display side by side
grid.arrange(lgb_cm_plot, nrow = 1)

grid.arrange(rf_cm_plot, nrow = 1)

grid.arrange(lr_cm_plot, nrow = 1)

# Per-class F1 scores, one column per model
f1_by_class <- data.frame(
  Class = factor(levels(y_test)),
  LightGBM = f1_lgb,
  RandomForest = f1_rf,
  LR = f1_lr
)

# Long format: one row per (class, model) pair
f1_long <- reshape2::melt(
  f1_by_class,
  id.vars = "Class",
  variable.name = "Model",
  value.name = "F1_Score"
)

# Grouped bars with the score printed above each bar
ggplot(f1_long, aes(x = Class, y = F1_Score, fill = Model)) +
  geom_col(position = position_dodge()) +
  geom_text(
    aes(label = sprintf("%.2f", F1_Score)),
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 3
  ) +
  scale_fill_viridis_d() +
  labs(
    title = "F1 Scores by Class and Model",
    x = "Class",
    y = "F1 Score"
  ) +
  theme_minimal() +
  ylim(0, 1)

# Compute and plot one-vs-rest ROC curves for a multiclass classifier.
#
# probs      : numeric matrix of class probabilities, one column per class,
#              columns in the same order as the class levels.
# true_class : factor of true labels, or 0-based integer class ids.
# model_name : label used in the plot title and console message.
#
# Draws all curves in the current plotting panel (the caller controls the
# layout via par(mfrow = ...)) and returns a list with `auc` (per-class AUC,
# NA for classes that cannot be evaluated) and `mean_auc`.
calculate_multiclass_roc <- function(probs, true_class, model_name) {
  # Take the class count from the input instead of a global variable
  n_classes <- ncol(probs)

  # Work with 0-based class ids; keep readable labels for the legend
  class_labels <- NULL
  if (is.factor(true_class)) {
    class_labels <- levels(true_class)
    true_class <- as.integer(true_class) - 1
  } else if (!is.null(colnames(probs))) {
    class_labels <- colnames(probs)
  }
  if (is.null(class_labels) || length(class_labels) != n_classes) {
    class_labels <- as.character(seq_len(n_classes) - 1)
  }

  auc_values <- rep(NA_real_, n_classes)
  colors <- rainbow(n_classes)
  first_curve <- TRUE

  for (i in seq_len(n_classes)) {
    # Binary problem: class i vs. rest
    binary_true <- as.integer(true_class == i - 1)

    # A ROC curve needs both positives and negatives; skip otherwise
    if (length(unique(binary_true)) < 2) {
      next
    }

    # Fix the positive class and the direction explicitly: letting pROC pick
    # the direction per class would flip below-chance curves and inflate AUC.
    roc_obj <- roc(
      response = binary_true,
      predictor = probs[, i],
      levels = c(0, 1),
      direction = "<",
      quiet = TRUE
    )
    auc_values[i] <- as.numeric(auc(roc_obj))

    # Plot first curve, then add others
    if (first_curve) {
      plot(roc_obj, col = colors[i],
           main = paste(model_name, "ROC Curves"),
           lwd = 2)
      first_curve <- FALSE
    } else {
      plot(roc_obj, col = colors[i], add = TRUE, lwd = 2)
    }
  }

  # Legend only makes sense when at least one curve was drawn
  if (!first_curve) {
    legend("bottomright",
           legend = paste("Class", class_labels, "AUC =", round(auc_values, 3)),
           col = colors, lwd = 2)
  }

  mean_auc <- mean(auc_values, na.rm = TRUE)
  cat(model_name, "Mean AUC:", mean_auc, "\n")
  list(auc = auc_values, mean_auc = mean_auc)
}

# Calculate ROC curves for each model
par(mfrow = c(1, 3))

# LightGBM
lgb_roc <- calculate_multiclass_roc(lgb_prob_matrix, y_test_0idx, "LightGBM")

## LightGBM Mean AUC: 0.5211887
# Random Forest
rf_roc <- calculate_multiclass_roc(as.matrix(rf_probs), y_test, "Random Forest")

## Random Forest Mean AUC: 0.8066222
# Logistic Regression
lr_roc <- calculate_multiclass_roc(as.matrix(lr_probs), y_test, "Logistic Regression")

## Logistic Regression Mean AUC: 0.7333402
par(mfrow = c(1, 1))

# Add AUC to metrics
metrics_df$Mean_AUC <- c(lgb_roc$mean_auc, rf_roc$mean_auc, lr_roc$mean_auc)
print(metrics_df)
##           Model  Accuracy       Kappa   Mean_F1  Mean_AUC
## 1      LightGBM 0.1674491 0.001565432 0.1658995 0.5211887
## 2 Random Forest 0.5023474 0.395947657 0.5087884 0.8066222
## 3            LR 0.4303599 0.312492980 0.4298563 0.7333402

9. Best Model and Feature Importance

# Find best model
best_idx <- which.max(metrics_df$Accuracy)
best_model <- metrics_df$Model[best_idx]

cat("Best model by accuracy:", best_model, "\n")
## Best model by accuracy: Random Forest
cat("Accuracy:", metrics_df$Accuracy[best_idx], "\n")
## Accuracy: 0.5023474
cat("Kappa:", metrics_df$Kappa[best_idx], "\n")
## Kappa: 0.3959477
cat("Mean F1:", metrics_df$Mean_F1[best_idx], "\n")
## Mean F1: 0.5087884
cat("Mean AUC:", metrics_df$Mean_AUC[best_idx], "\n")
## Mean AUC: 0.8066222
# Get best confusion matrix
best_cm <- switch(best_model,
                 "LightGBM" = lgb_cm,
                 "Random Forest" = rf_cm,
                 "LR" = lr_cm)

# Enhanced visualization of best model's confusion matrix.
# as.data.frame() on a caret confusion table yields the columns in the order
# Prediction, Reference, Freq, so they must be named in that order (naming
# them Reference, Prediction transposes the matrix).
cm_df <- as.data.frame(best_cm$table)
names(cm_df) <- c("Prediction", "Reference", "Freq")

# Percentages within each true class; ungroup so the result is a plain table
cm_df <- cm_df %>%
  group_by(Reference) %>%
  mutate(
    Total = sum(Freq),
    Percentage = Freq / Total * 100,
    Label = sprintf("%d\n(%.1f%%)", Freq, Percentage)
  ) %>%
  ungroup()

# Plot enhanced confusion matrix
ggplot(cm_df, aes(x = Prediction, y = Reference, fill = Percentage)) +
  geom_tile() +
  geom_text(aes(label = Label), color = "black") +
  scale_fill_gradient2(low = "white", high = "darkblue", mid = "skyblue",
                       midpoint = 50, limits = c(0, 100)) +
  labs(title = paste("Best Model:", best_model, "- Confusion Matrix"),
       subtitle = paste("Overall Accuracy:", round(metrics_df$Accuracy[best_idx] * 100, 2), "%"),
       x = "Predicted Class", y = "True Class") +
  theme_minimal()

# Analyze feature importance from the best model
if (best_model == "LightGBM") {
  lgb.plot.importance(lgb_importance, top_n = 20, measure = "Gain")
} else if (best_model == "Random Forest") {
  plot(varImp(rf_model), top = 20)
}

# For the tree-based models, look at which embedding indices matter most
if (best_model %in% c("LightGBM", "Random Forest")) {
  # Collect importance values plus the column holding the feature names
  if (best_model == "LightGBM") {
    importance_data <- lgb_importance
    importance_col <- "Gain"
    name_col <- "Feature"
  } else {
    importance_data <- varImp(rf_model)$importance
    # Rescale so the importances sum to 100
    importance_data$Overall <- importance_data$Overall / sum(importance_data$Overall) * 100
    importance_data$Variable <- rownames(importance_data)
    importance_col <- "Overall"
    name_col <- "Variable"
  }

  top_n <- 50  # Number of top features to analyze
  if (nrow(importance_data) >= top_n) {
    # Rank by importance first: varImp()$importance keeps the original
    # feature order, so taking the first rows would not give the top features.
    ranking <- order(importance_data[[importance_col]], decreasing = TRUE)
    top_features <- importance_data[ranking[seq_len(top_n)], ]

    # Numeric index that follows the "embedding_" prefix of a feature name
    extract_index <- function(feature_name) {
      as.numeric(sub("^embedding_", "", feature_name))
    }

    embedding_indices <- extract_index(top_features[[name_col]])
    top_features$Index <- embedding_indices

    # Distribution of the important embedding indices
    hist(top_features$Index, breaks = 20,
         main = "Distribution of Top Embedding Indices",
         xlab = "Embedding Index", col = "skyblue")
  }
}

9. Performance Comparison

# Collect performance metrics
models <- c("Random Forest", "LightGBM")
accuracy <- c(
  rf_cm$overall["Accuracy"],
  lgb_cm$overall["Accuracy"]
)
kappa <- c(
  rf_cm$overall["Kappa"],
  lgb_cm$overall["Kappa"]
)

# Create comparison dataframe
model_performance <- data.frame(
  Model = models,
  Accuracy = accuracy,
  Kappa = kappa
)

# Display table of performance metrics
print("Performance comparison:")
## [1] "Performance comparison:"
print(model_performance)
##           Model  Accuracy       Kappa
## 1 Random Forest 0.5023474 0.395947657
## 2      LightGBM 0.1674491 0.001565432
# Bar chart of accuracy per model, with the value printed above each bar
ggplot(model_performance, aes(x = Model, y = Accuracy, fill = Model)) +
  geom_col() +
  geom_text(aes(label = sprintf("%.4f", Accuracy)), vjust = -0.5) +
  scale_fill_viridis_d() +
  labs(
    title = "Model Accuracy Comparison",
    y = "Accuracy"
  ) +
  theme_minimal() +
  ylim(0, 1)

10. Error Analysis

# Get predictions from best model, expressed with the original class labels.
# LightGBM predictions are stored as 0-based class ids, so they are mapped
# back to label names; otherwise they could never match y_test.
class_levels <- levels(y_test)
best_preds <- switch(best_model,
                    "LightGBM" = factor(
                      class_levels[as.integer(as.character(lgb_preds)) + 1L],
                      levels = class_levels
                    ),
                    "Random Forest" = rf_preds,
                    "LR" = lr_preds)

# Convert both to character first to ensure matching factors
best_preds_char <- as.character(best_preds)
y_test_char <- as.character(y_test)

# Shared level set. The original level order of y_test comes first so that
# later per-class tables line up with the confusion-matrix rows (levels in
# order of first appearance would mislabel them).
all_classes <- union(class_levels, best_preds_char)

# Convert back to factors with the same levels
best_preds_factor <- factor(best_preds_char, levels = all_classes)
y_test_factor <- factor(y_test_char, levels = all_classes)

# Now find misclassified instances
misclassified <- which(best_preds_factor != y_test_factor)

if (length(misclassified) > 0) {
  misclass_pairs <- data.frame(
    True = y_test_factor[misclassified],
    Predicted = best_preds_factor[misclassified]
  )

  # Count each (true, predicted) pair, most frequent first
  misclass_counts <- misclass_pairs %>%
    group_by(True, Predicted) %>%
    summarise(Count = n(), .groups = 'drop') %>%
    arrange(desc(Count))

  # Display top misclassification patterns
  print("Top Misclassification Patterns:")
  print(head(misclass_counts, 10))

  # head() already copes with fewer than 10 rows, and at least one pattern
  # exists in this branch, so no further guard is needed
  plot_data <- head(misclass_counts, 10)

  print(
    ggplot(plot_data, aes(x = paste(True, "→", Predicted), y = Count, fill = True)) +
      geom_bar(stat = "identity") +
      geom_text(aes(label = Count), vjust = -0.5) +
      scale_fill_viridis_d() +
      labs(title = "Top Misclassification Patterns",
           subtitle = paste("Based on", best_model, "predictions"),
           x = "Misclassification (True → Predicted)", y = "Count") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, hjust = 1))
  )
} else {
  cat("No misclassified instances found!\n")
}
## [1] "Top Misclassification Patterns:"
## # A tibble: 10 × 3
##    True    Predicted Count
##    <fct>   <fct>     <int>
##  1 Other   Seizure      30
##  2 Seizure Other        24
##  3 LPD     Other        23
##  4 GRDA    Other        22
##  5 LRDA    Other        22
##  6 GRDA    Seizure      21
##  7 GPD     GRDA         21
##  8 Other   GRDA         18
##  9 LRDA    Seizure      17
## 10 LPD     GRDA         17

11. Prediction Confidence Analysis

# Get probability predictions from best model
best_probs <- switch(best_model,
                    "LightGBM" = lgb_prob_matrix,
                    "Random Forest" = as.matrix(rf_probs),
                    "LR" = as.matrix(lr_probs))

# Calculate confidence (max probability) for each prediction
confidence <- apply(best_probs, 1, max)

# Create data frame with prediction results
confidence_df <- data.frame(
  True_Class = y_test_factor,
  Predicted_Class = best_preds,
  Confidence = confidence,
  Correct = (best_preds == y_test_factor)
)

# Analyze confidence distribution for correct vs incorrect predictions
ggplot(confidence_df, aes(x = Confidence, fill = Correct)) +
  geom_density(alpha = 0.7) +
  scale_fill_manual(values = c("FALSE" = "firebrick", "TRUE" = "forestgreen")) +
  labs(title = "Prediction Confidence Distribution",
       subtitle = "Comparing correct vs. incorrect predictions",
       x = "Confidence (Maximum Probability)", y = "Density") +
  theme_minimal()

# Analyze accuracy at different confidence thresholds.
# For each threshold t, only predictions with confidence >= t are kept:
#   Accuracy = share of kept predictions that are correct (NA if none kept)
#   Coverage = share of all predictions that are kept
# vapply() guarantees numeric columns even when every accuracy is NA.
thresholds <- seq(0, 1, by = 0.05)
threshold_results <- data.frame(
  Threshold = thresholds,
  Accuracy = vapply(thresholds, function(t) {
    kept <- confidence_df$Confidence >= t
    if (any(kept)) mean(confidence_df$Correct[kept]) else NA_real_
  }, numeric(1)),
  Coverage = vapply(thresholds, function(t) {
    sum(confidence_df$Confidence >= t) / nrow(confidence_df)
  }, numeric(1))
)

# Plot accuracy vs threshold
ggplot(threshold_results, aes(x = Threshold)) +
  geom_line(aes(y = Accuracy, color = "Accuracy"), size = 1) +
  geom_line(aes(y = Coverage, color = "Coverage"), size = 1) +
  scale_color_manual(values = c("Accuracy" = "forestgreen", "Coverage" = "steelblue")) +
  labs(title = "Accuracy and Coverage vs. Confidence Threshold",
       x = "Confidence Threshold", y = "Value") +
  theme_minimal()

# Find optimal threshold for high accuracy while maintaining reasonable coverage.
# Accuracy * Coverage equals (correct predictions kept) / (all predictions),
# which can only shrink as the threshold rises, so maximising it always picks
# threshold 0. Instead, take the most accurate threshold among those that
# still cover at least `min_coverage` of the test set.
min_coverage <- 0.5  # share of predictions that must be retained; adjust as needed
eligible <- which(threshold_results$Coverage >= min_coverage &
                    !is.na(threshold_results$Accuracy))
if (length(eligible) > 0) {
  optimal_idx <- eligible[which.max(threshold_results$Accuracy[eligible])]
} else {
  optimal_idx <- 1L  # no eligible threshold: fall back to keeping everything
}
optimal_threshold <- threshold_results$Threshold[optimal_idx]

cat("Optimal confidence threshold:", optimal_threshold, "\n")
## Optimal confidence threshold: 0
cat("  - Accuracy at this threshold:", threshold_results$Accuracy[optimal_idx], "\n")
##   - Accuracy at this threshold: 0.5023474
cat("  - Coverage at this threshold:", threshold_results$Coverage[optimal_idx], "\n")
##   - Coverage at this threshold: 1

12. Per-Class Analysis

# Analyze per-class performance metrics.
# The rows of best_cm$byClass follow the level order of the reference factor,
# which corresponds to levels(y_test) for every model (LightGBM's 0-based ids
# are derived from those levels). Labelling from y_test keeps each row paired
# with its own class.
stopifnot(nrow(best_cm$byClass) == nlevels(y_test))
class_metrics <- data.frame(
  Class = levels(y_test),
  Precision = best_cm$byClass[, "Precision"],
  Recall = best_cm$byClass[, "Recall"],
  F1 = best_cm$byClass[, "F1"],
  Specificity = best_cm$byClass[, "Specificity"]
)

# Calculate class-specific confusion: for each true class, tabulate what the
# best model predicted for its instances
class_errors <- lapply(levels(y_test), function(class_label) {
  table(best_preds[y_test == class_label])
})
names(class_errors) <- levels(y_test)

# Print per-class metrics
print("Per-class performance metrics:")
## [1] "Per-class performance metrics:"
print(class_metrics)
##                  Class Precision    Recall        F1 Specificity
## Class: GPD       Other 0.7763158 0.6145833 0.6860465   0.9686924
## Class: GRDA       GRDA 0.4236111 0.5495495 0.4784314   0.8428030
## Class: LPD     Seizure 0.6222222 0.3146067 0.4179104   0.9690909
## Class: LRDA       LRDA 0.7454545 0.4226804 0.5394737   0.9741697
## Class: Other       LPD 0.3701299 0.4871795 0.4206642   0.8141762
## Class: Seizure     GPD 0.4545455 0.5813953 0.5102041   0.8235294
# Long format: one row per (class, metric) pair
metrics_long <- reshape2::melt(
  class_metrics,
  id.vars = "Class",
  variable.name = "Metric",
  value.name = "Value"
)

# Grouped bars per class, one bar per metric, labelled with the value
ggplot(metrics_long, aes(x = Class, y = Value, fill = Metric)) +
  geom_col(position = position_dodge()) +
  geom_text(
    aes(label = sprintf("%.2f", Value)),
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 2.5
  ) +
  scale_fill_viridis_d() +
  labs(
    title = "Performance Metrics by Class",
    subtitle = paste("Using", best_model),
    x = "Class",
    y = "Score"
  ) +
  theme_minimal() +
  ylim(0, 1)

# Identify the most challenging class
worst_class_idx <- which.min(class_metrics$F1)
worst_class <- class_metrics$Class[worst_class_idx]

cat("Most challenging class:", worst_class, "\n")
## Most challenging class: Seizure
cat("F1 score:", class_metrics$F1[worst_class_idx], "\n")
## F1 score: 0.4179104
cat("Confusion pattern:\n")
## Confusion pattern:
print(class_errors[[worst_class]])
## 
##     GPD    GRDA     LPD    LRDA   Other Seizure 
##       7      15       8       0      24      75